# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "kokoro==0.8.4",
#     "onnx==1.17.0",
#     "onnxruntime==1.20.1",
#     "sounddevice==0.5.1",
# ]
#
# ///

"""
From https://github.com/hexgrad/kokoro/blob/3f9dd88d6f739b98a86aea608e238621f5b40add/examples/export.py

mkdir checkpoints
wget https://huggingface.co/hexgrad/Kokoro-82M-v1.1-zh/resolve/main/config.json -O checkpoints/config.json
wget https://huggingface.co/hexgrad/Kokoro-82M-v1.1-zh/resolve/main/kokoro-v1_1-zh.pth -O checkpoints/kokoro-v1_1-zh.pth
uv run examples/export.py
uv run examples/export.py --config_file checkpoints/config.json --checkpoint_path checkpoints/kokoro-v1_1-zh.pth
"""

import argparse
import os

import onnx
import onnxruntime as ort
import sounddevice as sd
import torch
from kokoro import KModel, KPipeline
from kokoro.model import KModelForONNX


def export_onnx(model, output):
    onnx_file = output + "/" + "kokoro.onnx"

    input_ids = torch.randint(1, 100, (48,)).numpy()
    input_ids = torch.LongTensor([[0, *input_ids, 0]])
    style = torch.randn(1, 256)
    speed = torch.randint(1, 10, (1,)).int()

    torch.onnx.export(
        model,
        args=(input_ids, style, speed),
        f=onnx_file,
        export_params=True,
        verbose=True,
        input_names=["input_ids", "style", "speed"],
        output_names=["waveform", "duration"],
        opset_version=17,
        dynamic_axes={
            "input_ids": {1: "input_ids_len"},
            "waveform": {0: "num_samples"},
        },
        do_constant_folding=True,
    )

    print("export kokoro.onnx ok!")

    onnx_model = onnx.load(onnx_file)
    onnx.checker.check_model(onnx_model)
    print("onnx check ok!")


def load_input_ids(pipeline, text):
    if pipeline.lang_code in "ab":
        _, tokens = pipeline.g2p(text)
        for gs, ps, tks in pipeline.en_tokenize(tokens):
            if not ps:
                continue
    else:
        ps, _ = pipeline.g2p(text)

    if len(ps) > 510:
        ps = ps[:510]

    input_ids = list(
        filter(lambda i: i is not None, map(lambda p: pipeline.model.vocab.get(p), ps))
    )
    print(f"text: {text} -> phonemes: {ps} -> input_ids: {input_ids}")
    input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(pipeline.model.device)
    return ps, input_ids


def load_voice(pipeline, voice, phonemes):
    pack = pipeline.load_voice(voice).to("cpu")
    return pack[len(phonemes) - 1]


def load_sample(model):
    pipeline = KPipeline(lang_code="a", model=model.kmodel, device="cpu")
    text = """
    In today's fast-paced tech world, building software applications has never been easier — thanks to AI-powered coding assistants.'
    """
    text = """
    The sky above the port was the color of television, tuned to a dead channel.
    """
    voice = "checkpoints/voices/af_heart.pt"

    pipeline = KPipeline(lang_code="z", model=model.kmodel, device="cpu")
    text = """
    2月15日晚，猫眼专业版数据显示，截至发稿，《哪吒之魔童闹海》（或称《哪吒2》）今日票房已达7.8亿元，累计票房（含预售）超过114亿元。
    """
    voice = "checkpoints/voices/zf_xiaoxiao.pt"

    phonemes, input_ids = load_input_ids(pipeline, text)
    style = load_voice(pipeline, voice, phonemes)
    speed = torch.IntTensor([1])

    return input_ids, style, speed


def inference_onnx(model, output):
    onnx_file = output + "/" + "kokoro.onnx"
    session = ort.InferenceSession(onnx_file)

    input_ids, style, speed = load_sample(model)

    outputs = session.run(
        None,
        {
            "input_ids": input_ids.numpy(),
            "style": style.numpy(),
            "speed": speed.numpy(),
        },
    )

    output = torch.from_numpy(outputs[0])
    print(f"output: {output.shape}")
    print(output)

    audio = output.numpy()
    sd.play(audio, 24000)
    sd.wait()


def check_model(model):
    input_ids, style, speed = load_sample(model)
    output, duration = model(input_ids, style, speed)

    print(f"output: {output.shape}")
    print(f"duration: {duration.shape}")
    print(output)

    audio = output.numpy()
    sd.play(audio, 24000)
    sd.wait()


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Export kokoro Model to ONNX", add_help=True)
    parser.add_argument(
        "--inference", "-t", help="test kokoro.onnx model", action="store_true"
    )
    parser.add_argument("--check", "-m", help="check kokoro model", action="store_true")
    parser.add_argument(
        "--config_file",
        "-c",
        type=str,
        default="checkpoints/config.json",
        help="path to config file",
    )
    parser.add_argument(
        "--checkpoint_path",
        "-p",
        type=str,
        default="checkpoints/kokoro-v1_0.pth",
        help="path to checkpoint file",
    )
    parser.add_argument(
        "--output_dir", "-o", type=str, default="onnx", help="output directory"
    )

    args = parser.parse_args()

    # cfg
    config_file = args.config_file  # change the path of the model config file
    checkpoint_path = args.checkpoint_path  # change the path of the model
    output_dir = args.output_dir

    # make dir
    os.makedirs(output_dir, exist_ok=True)

    kmodel = KModel(config=config_file, model=checkpoint_path, disable_complex=True)
    model = KModelForONNX(kmodel).eval()

    if args.inference:
        inference_onnx(model, output_dir)
    elif args.check:
        check_model(model)
    else:
        export_onnx(model, output_dir)
